/*
 * Copyright 2017-2018 NXP
 * All rights reserved.
 *
 * SPDX-License-Identifier: BSD-3-Clause
 */

#include "dct2d_powerquad.h"

#include "fsl_powerquad.h"
#include <stdio.h>

/*******************************************************************************
 * Definition.
 ******************************************************************************/
/* Base address of the PowerQuad private RAM. The DCT/IDCT routines in this
 * file use it as the scratch location for the intermediate matrix product. */
#define POWERQUAD_PRIVATE_RAM_BASE (0xE0000000U)

/*******************************************************************************
 * Variables.
 ******************************************************************************/
/* Temporary work buffer handed to PQ_MatrixInversion() by
 * dct2d_powerquad_init().
 * NOTE(review): 1024 words presumably matches the temporary-buffer size the
 * PowerQuad driver requires for its largest supported inversion - confirm
 * against fsl_powerquad.h. */
uint32_t powerquad_matrix_inv_tmp[1024];
/* DCT transform matrix A, filled in by dct2d_powerquad_init(). */
float dct2d_A_matrix_f32[DCT2D_MATRIX_N][DCT2D_MATRIX_N];
/* Inverse of A, computed on the PowerQuad by dct2d_powerquad_init(). */
float dct2d_A_inv_matrix_f32[DCT2D_MATRIX_N][DCT2D_MATRIX_N];
/* NOTE(review): the two buffers below are not referenced in this file;
 * presumably they are result/work buffers for callers - confirm. */
float dct2d_output_matrix_f32[DCT2D_MATRIX_N][DCT2D_MATRIX_N];
float dct2d_tmp_matrix_f32[DCT2D_MATRIX_N][DCT2D_MATRIX_N];

/*******************************************************************************
 * Declarations.
 ******************************************************************************/
/* Scalar PowerQuad helpers defined later in this file. They are declared up
 * front because dct2d_powerquad_init() calls them before their definitions.
 * Only helpers that actually have a definition here are declared. */
float powerquad_inv(float input);
float powerquad_cos(float input);
float powerquad_sqrt(float input);

/*******************************************************************************
 * Code
 ******************************************************************************/
/*!
 * @brief Initialize the PowerQuad and precompute the DCT transform matrices.
 *
 * With N = DCT2D_MATRIX_N, fills dct2d_A_matrix_f32 with
 *   A[0][j] = sqrt(1/N)
 *   A[i][j] = sqrt(2/N) * cos(i * (2j + 1) * PI / (2N)),  for i >= 1
 * using the PowerQuad scalar helpers. It then inverts A on the PowerQuad
 * matrix engine (float format) into dct2d_A_inv_matrix_f32.
 *
 * dct2d_powerquad_dct_f32() / dct2d_powerquad_idct_f32() read both matrices,
 * so this must run first.
 *
 * A as constructed is orthonormal, so its inverse is mathematically A^T. The
 * hardware inversion is kept as in the original design.
 * NOTE(review): the PowerQuad inversion only supports small matrices and needs
 * the powerquad_matrix_inv_tmp work buffer; confirm DCT2D_MATRIX_N is within
 * the limit documented in fsl_powerquad.h.
 */
void dct2d_powerquad_init(void)
{
    PQ_Init(POWERQUAD);

    uint32_t length = POWERQUAD_MAKE_MATRIX_LEN(DCT2D_MATRIX_N, DCT2D_MATRIX_N, DCT2D_MATRIX_N);
    float inv_n;
    float scale;

    /* Generate the transform matrix A. */
    inv_n = powerquad_inv((float)DCT2D_MATRIX_N);

    /* Row 0 is constant: sqrt(1/N). */
    scale = powerquad_sqrt(inv_n);
    for (uint32_t j = 0u; j < DCT2D_MATRIX_N; j++)
    {
        dct2d_A_matrix_f32[0][j] = scale;
    }

    /* Rows 1..N-1: sqrt(2/N) * cos(i * (2j + 1) * PI / (2N)). */
    scale = powerquad_sqrt(2.0f * inv_n);
    for (uint32_t i = 1u; i < DCT2D_MATRIX_N; i++)
    {
        for (uint32_t j = 0u; j < DCT2D_MATRIX_N; j++)
        {
            dct2d_A_matrix_f32[i][j] = scale * powerquad_cos(i * (2u * j + 1u) * (0.5f * PI * inv_n));
        }
    }

    /* A_inv = inverse(A), computed in float on the matrix engine. The DCT/IDCT
     * routines in this file do not set the matrix format themselves, so they
     * rely on this float setting staying in effect. */
    PQ_SetFormat(POWERQUAD, kPQ_CP_MTX, kPQ_Float);
    PQ_MatrixInversion(POWERQUAD, length, dct2d_A_matrix_f32, powerquad_matrix_inv_tmp, dct2d_A_inv_matrix_f32);
    PQ_WaitDone(POWERQUAD);
}

void dct2d_powerquad_dct_f32(float32_t *input_matrix, float32_t *output_matrix)
{
    uint32_t length = POWERQUAD_MAKE_MATRIX_LEN(DCT2D_MATRIX_N, DCT2D_MATRIX_N, DCT2D_MATRIX_N);

    PQ_MatrixMultiplication(POWERQUAD, length, dct2d_A_matrix_f32, input_matrix, (void *)POWERQUAD_PRIVATE_RAM_BASE);
    PQ_WaitDone(POWERQUAD);
    PQ_MatrixMultiplication(POWERQUAD, length, (void *)POWERQUAD_PRIVATE_RAM_BASE, dct2d_A_inv_matrix_f32, output_matrix);
    PQ_WaitDone(POWERQUAD);
}

void dct2d_powerquad_idct_f32(float32_t *input_matrix, float32_t *output_matrix)
{
    uint32_t length = POWERQUAD_MAKE_MATRIX_LEN(DCT2D_MATRIX_N, DCT2D_MATRIX_N, DCT2D_MATRIX_N);

    PQ_MatrixMultiplication(POWERQUAD, length, dct2d_A_inv_matrix_f32, input_matrix, (void *)POWERQUAD_PRIVATE_RAM_BASE);
    PQ_WaitDone(POWERQUAD);
    PQ_MatrixMultiplication(POWERQUAD, length, (void *)POWERQUAD_PRIVATE_RAM_BASE, dct2d_A_matrix_f32, output_matrix);
    PQ_WaitDone(POWERQUAD);
}

/* Returns the reciprocal (1 / input), evaluated by the PowerQuad through
 * PQ_InvF32(). */
float powerquad_inv(float input)
{
    float operand    = input;
    float reciprocal = 0.0f;

    PQ_InvF32(&operand, &reciprocal);

    return reciprocal;
}

/* Returns the cosine of input, evaluated by the PowerQuad through PQ_CosF32().
 * NOTE(review): input is presumably in radians (the DCT setup passes
 * multiples of PI / (2N)) - confirm against the PQ_CosF32 documentation. */
float powerquad_cos(float input)
{
    float angle  = input;
    float cosine = 0.0f;

    PQ_CosF32(&angle, &cosine);

    return cosine;
}

/* Returns the square root of input, evaluated by the PowerQuad through
 * PQ_SqrtF32(). */
float powerquad_sqrt(float input)
{
    float radicand = input;
    float root     = 0.0f;

    PQ_SqrtF32(&radicand, &root);

    return root;
}

/*
 * Prints row_num x col_num floats stored contiguously (consumed row by row)
 * to stdout. Each element is printed as "%10.5f " and each row ends with
 * "\r\n".
 */
void print_matrix_f32(float *matrix, uint32_t row_num, uint32_t col_num)
{
    const float *row_start = matrix;
    uint32_t row           = 0u;

    while (row < row_num)
    {
        for (uint32_t col = 0u; col < col_num; ++col)
        {
            printf("%10.5f ", row_start[col]);
        }
        printf("\r\n");

        row_start += col_num;
        ++row;
    }
}

/* EOF. */

